{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Step: 2632  | total loss: \u001b[1m\u001b[32m2.34684\u001b[0m\u001b[0m | time: 4301.448s\n",
      "\u001b[2K\r",
      "| Adam | epoch: 001 | loss: 2.34684 -- iter: 0336896/1371994\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-1e8e306d07c2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     45\u001b[0m     \u001b[0mseed\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrandom_sequence_from_textfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmaxlen\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m     m.fit(X, Y, validation_set=0.1, batch_size=128,\n\u001b[0;32m---> 47\u001b[0;31m           n_epoch=1, run_id='shakespeare')\n\u001b[0m\u001b[1;32m     48\u001b[0m     \u001b[0;32mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"-- TESTING...\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     49\u001b[0m     \u001b[0;32mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"-- Test with temperature of 1.0 --\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/gulli/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tflearn/models/generator.pyc\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X_inputs, Y_targets, n_epoch, validation_set, show_metric, batch_size, shuffle, snapshot_epoch, snapshot_step, excl_trainops, run_id)\u001b[0m\n\u001b[1;32m    172\u001b[0m                          \u001b[0mdaug_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdaug_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    173\u001b[0m                          \u001b[0mexcl_trainops\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mexcl_trainops\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 174\u001b[0;31m                          run_id=run_id)\n\u001b[0m\u001b[1;32m    175\u001b[0m         self.predictor = Evaluator([self.net],\n\u001b[1;32m    176\u001b[0m                                    session=self.trainer.session)\n",
      "\u001b[0;32m/Users/gulli/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tflearn/helpers/trainer.pyc\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, feed_dicts, n_epoch, val_feed_dicts, show_metric, snapshot_step, snapshot_epoch, shuffle_all, dprep_dict, daug_dict, excl_trainops, run_id, callbacks)\u001b[0m\n\u001b[1;32m    337\u001b[0m                                                        \u001b[0;34m(\u001b[0m\u001b[0mbool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbest_checkpoint_path\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0msnapshot_epoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    338\u001b[0m                                                        \u001b[0msnapshot_step\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 339\u001b[0;31m                                                        show_metric)\n\u001b[0m\u001b[1;32m    340\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    341\u001b[0m                             \u001b[0;31m# Update training state\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/gulli/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tflearn/helpers/trainer.pyc\u001b[0m in \u001b[0;36m_train\u001b[0;34m(self, training_step, snapshot_epoch, snapshot_step, show_metric)\u001b[0m\n\u001b[1;32m    816\u001b[0m         \u001b[0mtflearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msession\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    817\u001b[0m         _, train_summ_str = self.session.run([self.train, self.summ_op],\n\u001b[0;32m--> 818\u001b[0;31m                                              feed_batch)\n\u001b[0m\u001b[1;32m    819\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    820\u001b[0m         \u001b[0;31m# Retrieve loss value from summary string\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/gulli/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    893\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    894\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 895\u001b[0;31m                          run_metadata_ptr)\n\u001b[0m\u001b[1;32m    896\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    897\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/gulli/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1122\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1123\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m-> 1124\u001b[0;31m                              feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[1;32m   1125\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1126\u001b[0m       \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/gulli/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1319\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1320\u001b[0m       return self._do_call(_run_fn, self._session, feeds, fetches, targets,\n\u001b[0;32m-> 1321\u001b[0;31m                            options, run_metadata)\n\u001b[0m\u001b[1;32m   1322\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1323\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/gulli/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m   1325\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1326\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1327\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1328\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1329\u001b[0m       \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/Users/gulli/miniconda2/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/client/session.pyc\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m   1304\u001b[0m           return tf_session.TF_Run(session, options,\n\u001b[1;32m   1305\u001b[0m                                    \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1306\u001b[0;31m                                    status, run_metadata)\n\u001b[0m\u001b[1;32m   1307\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1308\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msession\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function\n",
    "\n",
    "import os\n",
    "import pickle\n",
    "from six.moves import urllib\n",
    "\n",
    "import tflearn\n",
    "from tflearn.data_utils import *\n",
    "\n",
    "# Local corpus file and the pickle file that caches char_idx between runs.\n",
    "path = \"shakespeare_input.txt\"\n",
    "char_idx_file = 'char_idx.pickle'\n",
    "\n",
    "# Download the corpus only when it is not already on disk.\n",
    "if not os.path.isfile(path):\n",
    "    urllib.request.urlretrieve(\"https://raw.githubusercontent.com/tflearn/tflearn.github.io/master/resources/shakespeare_input.txt\", path)\n",
    "\n",
    "# Sequence length shared by the training sequences and the network input.\n",
    "maxlen = 25\n",
    "\n",
    "# Reuse a previously saved char_idx when one exists; otherwise None is passed\n",
    "# as pre_defined_char_idx below and the helper returns the char_idx to use.\n",
    "char_idx = None\n",
    "if os.path.isfile(char_idx_file):\n",
    "    print('Loading previous char_idx')\n",
    "    # NOTE(review): pickle.load can execute arbitrary code; only load a\n",
    "    # char_idx file that this notebook wrote itself.\n",
    "    # The context manager closes the handle even if unpickling fails.\n",
    "    with open(char_idx_file, 'rb') as char_idx_in:\n",
    "        char_idx = pickle.load(char_idx_in)\n",
    "\n",
    "X, Y, char_idx = textfile_to_semi_redundant_sequences(\n",
    "    path, seq_maxlen=maxlen, redun_step=3, pre_defined_char_idx=char_idx)\n",
    "\n",
    "# Save char_idx for the next run; closing the file guarantees the pickle is\n",
    "# fully flushed to disk before training starts.\n",
    "with open(char_idx_file, 'wb') as char_idx_out:\n",
    "    pickle.dump(char_idx, char_idx_out)\n",
    "\n",
    "# Input placeholder shaped [None, maxlen, len(char_idx)].\n",
    "g = tflearn.input_data([None, maxlen, len(char_idx)])\n",
    "\n",
    "# Three 512-unit LSTM layers, each followed by dropout(0.5). The first two\n",
    "# return the whole sequence so the next LSTM can consume it; the last passes\n",
    "# return_seq=False, which is tflearn's default for lstm().\n",
    "for return_seq in (True, True, False):\n",
    "    g = tflearn.lstm(g, 512, return_seq=return_seq)\n",
    "    g = tflearn.dropout(g, 0.5)\n",
    "\n",
    "# Softmax over len(char_idx) outputs, trained with Adam on categorical\n",
    "# cross-entropy.\n",
    "g = tflearn.fully_connected(g, len(char_idx), activation='softmax')\n",
    "g = tflearn.regression(g, optimizer='adam',\n",
    "                       loss='categorical_crossentropy', learning_rate=0.001)\n",
    "\n",
    "# Wrap the graph in a generator; checkpoints use the 'model_shakespeare' path.\n",
    "m = tflearn.SequenceGenerator(\n",
    "    g,\n",
    "    dictionary=char_idx,\n",
    "    seq_maxlen=maxlen,\n",
    "    clip_gradients=5.0,\n",
    "    checkpoint_path='model_shakespeare')\n",
    "\n",
    "# 50 rounds of one training epoch each. A fresh seed sequence is drawn from\n",
    "# the corpus before every round and then used to print two generated samples\n",
    "# (m.generate(600, ...)) at temperatures 1.0 and 0.5.\n",
    "for i in range(50):\n",
    "    seed = random_sequence_from_textfile(path, maxlen)\n",
    "    m.fit(X, Y, validation_set=0.1, batch_size=128, n_epoch=1,\n",
    "          run_id='shakespeare')\n",
    "    print(\"-- TESTING...\")\n",
    "    for banner, temperature in ((\"-- Test with temperature of 1.0 --\", 1.0),\n",
    "                                (\"-- Test with temperature of 0.5 --\", 0.5)):\n",
    "        print(banner)\n",
    "        print(m.generate(600, temperature=temperature, seq_seed=seed))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
